import gym
import torch
import numpy as np
import random, os, wandb
from runner import Runner
# from smac.env import StarCraft2Env
from common.arguments import get_common_args, get_coma_args, get_mixer_args, get_centralv_args, get_reinforce_args, \
    get_commnet_args, get_g2anet_args, get_qplex_args
from env.gym_foo.gym_foo.envs.pac_men import CustomEnv
from env.wrapper_grf_2vs2 import make_football_env_2vs2
from env.wrapper_mpe import make_mpe_env

from env.wrapper_grf_3vs2 import make_football_env_3vs1
from env.wrapper_grf_3_vs_2_full import make_football_env_3vs3_full
from env.wrapper_grf_3_vs_5_full import make_football_env_3vs5_full
from env.wrapper_grf_4vs4 import make_football_env_4vs4
from env.wrapper_grf_3vs3 import make_football_env_3vs3
from env.wrapper_grf_3vs4_BackBall import make_football_env_3vs4_BackBall
from env.wrapper_grf_counterattack import make_football_env_counterattack

from env.wrapper_grf_4vs3 import make_football_env_4vs3
from env.wrapper_grf_4vs5 import make_football_env_4vs5
from env.wrapper_grf_11vs4 import make_football_env_11vs4
from env.wrapper_grf_5vs3 import make_football_env_5vs3
from env.wrapper_grf_5vs5 import make_football_env_5vs5
from env.wrapper_star_MM2 import make_StarCraft_MM2


def setup_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU generator and all CUDA
    devices), and configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms and may pick a different
    # one on each run, so it has to be off for seeded runs to be repeatable.
    torch.backends.cudnn.benchmark = False
    # Cap the number of intra-op CPU threads PyTorch uses.
    torch.set_num_threads(8)


def get_env(args):
    """Build the environment selected by ``args.env``.

    Selection order:
      1. any name containing ``'pacmen'`` -> ``CustomEnv``;
      2. an exact match against the Google Research Football scenario table;
      3. any name containing ``'py'`` -> ``make_mpe_env``;
      4. anything else -> ``make_StarCraft_MM2``.

    Args:
        args: Parsed arguments. ``env`` and ``seed`` are read here (plus ``m``
            and ``c`` for pacmen); the whole object is forwarded to the MPE
            and StarCraft factories.

    Returns:
        A ``(env, GRF)`` tuple. ``GRF`` is ``False`` for the MPE and StarCraft
        branches and ``True`` for every other branch (including pacmen).
    """
    print(args.env)

    # Google Research Football scenarios, keyed by the exact ``args.env`` name.
    # Every factory takes ``(seed, dense_reward=False)``.
    football_factories = {
        '2_vs_2': make_football_env_2vs2,
        '3_vs_2': make_football_env_3vs1,
        '3_vs_3_full': make_football_env_3vs3_full,
        '3_vs_3': make_football_env_3vs3,
        '3_vs_4_BackBall': make_football_env_3vs4_BackBall,
        '3_vs_5_full': make_football_env_3vs5_full,
        '4_vs_4': make_football_env_4vs4,
        '4_vs_3': make_football_env_4vs3,
        '4_vs_5': make_football_env_4vs5,
        '11_vs_4': make_football_env_11vs4,
        '5_vs_3': make_football_env_5vs3,
        '5_vs_5': make_football_env_5vs5,
        'counterattack': make_football_env_counterattack,
    }

    if 'pacmen' in args.env:
        env = CustomEnv(4, args.seed, args.m, args.c, 'tini')
        env.seed(args.seed)
        return env, True

    if args.env in football_factories:
        env = football_factories[args.env](args.seed, dense_reward=False)
        return env, True

    if 'py' in args.env:
        # Multi-agent particle environment (not football).
        return make_mpe_env(args, dense_reward=False), False

    # Fallback: StarCraft II. Only supply the default install location when
    # the caller has not already exported SC2PATH, so an existing setting is
    # respected instead of being overwritten.
    os.environ.setdefault('SC2PATH', "/home/liuboyin/StarCraftII")
    return make_StarCraft_MM2(args), False
def new_set_args(args):
    """Apply experiment-specific overrides to ``args`` in place and return it.

    Picks a random CUDA device string (``cuda:0`` .. ``cuda:2``), chooses the
    training length from ``args.env``, and then toggles a handful of options
    based on substrings of ``args.label`` / ``args.alg``.

    Args:
        args: Mutable namespace with at least ``env``, ``label`` and ``alg``
            (and ``m`` / ``c`` when ``env`` contains ``'pacmen'``).

    Returns:
        The same ``args`` object, mutated.
    """
    device_index = random.randint(0, 2)
    args.GPU = "cuda:" + str(device_index)
    print(args.GPU)

    # Base epoch budget for scenarios that are matched by exact name.
    exact_budgets = {
        '3s_vs_8z': 30000,
        'MMM2': 40000,
        '3s5z_vs_3s6z': 65000,
        '3_vs_2': 80000,
        'corridor': 80000,
        '2_vs_3': 150000,
        '12m_vs_14m': 150000,
        '3_vs_4_BackBall': 150000,
        '3_vs_3_full': 200000,
        '3_vs_5_full': 200000,
        'counterattack': 200000,
        '6h_vs_8z': 200000,
    }

    scenario = args.env
    if 'pacmen' in scenario:
        # Pacmen runs also get a label suffix and a fixed QPLEX mixer.
        args.label = args.label + f'_m{args.m}_c{args.c}'
        args.QPLEX_mixer = 'dmaq'
        base_epochs = 70000
    elif scenario in exact_budgets:
        base_epochs = exact_budgets[scenario]
    elif 'm_vs_m' in scenario:
        base_epochs = 40000
    else:
        base_epochs = 100000
    # Every scenario trains for an extra 20000 epochs on top of its base.
    args.n_epoch = base_epochs + 20000

    if any(tag in args.label for tag in ('qmix', 'qplex')):
        args.uRNN = False
        args.beta = 0

    if any(tag in args.label for tag in ('debug', 'qplex')):
        args.wandb = False
        args.batch_size = 8

    if 'maven' in args.alg:
        args.cuda = False

    # Sweep runs replace the epoch budget computed above.
    if 'sweep' in args.label:
        args.n_epoch = 13000

    if scenario == '6h_vs_8z':
        args.td_lambda = 0.3
        args.epsilon_anneal_time = 500000

    return args

if __name__ == '__main__':
    # Script entry point: parse arguments, build the environment, optionally
    # start a wandb run, then train with Runner.
    os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'

    # Layer algorithm-specific defaults on top of the common arguments.
    args = get_common_args()
    if 'coma' in args.alg:
        args = get_coma_args(args)
    # NOTE(review): this chain starts with `if`, not `elif`, so a COMA run
    # also falls through to the final `else` and receives get_mixer_args.
    # Confirm whether that is intended.
    if 'qplex' in args.alg:
        args = get_mixer_args(args)
        args = get_qplex_args(args)
    elif 'central_v' in args.alg:
        args = get_centralv_args(args)
    elif 'reinforce' in args.alg:
        args = get_reinforce_args(args)
    else:
        args = get_mixer_args(args)
    if 'commnet' in args.alg:
        args = get_commnet_args(args)
    if 'g2anet' in args.alg:
        args = get_g2anet_args(args)

    args = new_set_args(args)

    # args.seed = random.randint(0, 1000000)
    setup_seed(args.seed)
    env, GRF = get_env(args)
    # From here on the environment exists, so make sure it is closed even if
    # setup or training raises.
    try:
        args.GRF = GRF
        print('GRF_Label:', GRF)

        print('--------------', args.env, '------------')
        print('label:', args.label, args.GPU, 'seed:', args.seed)

        # Settings for the "future MARL" variant.
        args.indi_latent_dim = 64
        args.evaluate_cycle = 200

        # Copy the environment's dimensions onto args; optional keys fall back
        # to fixed defaults when the environment does not report them.
        env_info = env.get_env_info()
        args.n_actions = env_info["n_actions"]
        args.n_agents = env_info["n_agents"]
        args.state_shape = env_info["state_shape"]
        args.obs_shape = env_info["obs_shape"]
        args.p_state = env_info["p_state"]
        args.episode_limit = env_info["episode_limit"]
        args.loc_dim = env_info["reltive_loc_dim"] if 'reltive_loc_dim' in env_info else 4
        args.visual_r = env_info["visual_r"] if 'visual_r' in env_info else 1e5
        args.own_feats_dim = env_info["own_feats_dim"] if 'own_feats_dim' in env_info else 1
        args.ally_feats_dim = env_info["ally_feats_dim"] if 'ally_feats_dim' in env_info else (args.n_agents - 1) * 5

        if args.wandb:
            wandb_name = f'seed_{args.seed}_{args.label}'
            # Pacmen runs are grouped per (m, c) setting; others per env name.
            if 'pacmen' in args.env:
                group = args.env + f'_m{args.m}_c{args.c}'
            else:
                group = args.env
            wandb.init(project=args.project_name, name=wandb_name, group=group, job_type=args.label, config=args)
            # Only the run with seed 1256 uploads a snapshot of the key source files.
            if args.seed == 1256:
                wandb.run.log_code("./", include_fn=lambda path: path.endswith("base_net.py") or path.endswith(
                    "arguments.py") or path.endswith("qmix_RNNfuture.py"))

        # Build the networks and run training.
        runner = Runner(env, args)
        runner.run(0)
    finally:
        env.close()
